{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DhFpDRo_F4By"
      },
      "source": [
        "# License\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at:\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7BQOCQE5tZiU"
      },
      "source": [
        "# Composable Function-preserving Expansions for Transformer Architectures\n",
        "\n",
        "This notebook contains implementations of the six function-preserving transformations of transformer-based models proposed in \"Composable Function-preserving Expansions for Transformer Architectures\". We provide a basic implementation of a generic transformer architecture and show that each transformation is function-preserving, both for individual architectural components and the whole transformer model, as well as for individual transformations and combinations of transformations.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PsIAeZNMs7WV"
      },
      "source": [
        "## Imports\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ps8_MboBu6XP"
      },
      "outputs": [],
      "source": [
        "# On Colab, we recommend a GPU or CPU runtime.\n",
        "\n",
        "import numpy as np\n",
        "import math\n",
        "from typing import Any, Callable, Tuple\n",
        "\n",
        "try:\n",
        "  import jax\n",
        "except ModuleNotFoundError: # Install jax if missing\n",
        "  !pip install --quiet jax\n",
        "  import jax\n",
        "\n",
        "import jax.numpy as jnp\n",
        "from jax import random\n",
        "\n",
        "# Seeding for random operations\n",
        "main_rng = random.PRNGKey(42)\n",
        "\n",
        "try:\n",
        "  import flax\n",
        "except ModuleNotFoundError: # Install flax if missing\n",
        "  !pip install --quiet flax\n",
        "  import flax\n",
        "from flax import linen as nn\n",
        "\n",
        "print(\"Device:\", jax.devices()[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LBODnnK8tBgy"
      },
      "source": [
        "## Name constants"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sFXdCSHSfVb2"
      },
      "outputs": [],
      "source": [
        "NAME_P = \"PosE\"\n",
        "NAME_W_OUT = \"W_OUT\"\n",
        "NAME_MLP_l1 = \"MLP_l1\"\n",
        "NAME_MLP_l2 = \"MLP_l2\"\n",
        "NAME_W_Oe = \"W_Oe\"\n",
        "NAME_W_Q = \"W_Q\"\n",
        "NAME_W_K = \"W_K\"\n",
        "NAME_W_V = \"W_V\"\n",
        "NAME_norm1 = \"norm1\"\n",
        "NAME_norm2 = \"norm2\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eh57ZstYtLpa"
      },
      "source": [
        "## Architectural components"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nx6J7Hzgu6XR"
      },
      "outputs": [],
      "source": [
        "def scaled_dot_product(q, k, v, mask=None):\n",
        "  d_k = q.shape[-1]\n",
        "  attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))\n",
        "  attn_logits = attn_logits / math.sqrt(d_k)\n",
        "  if mask is not None:\n",
        "    attn_logits = jnp.where(mask == 0, -9e15, attn_logits)\n",
        "  attention = nn.softmax(attn_logits, axis=-1)\n",
        "  values = jnp.matmul(attention, v)\n",
        "  return values, attention"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "damIJOUxVcEU"
      },
      "outputs": [],
      "source": [
        "class Attention(nn.Module):\n",
        "  dim_k : int # key/query representation dimension\n",
        "  dim_v : int # attention head representation dimension\n",
        "\n",
        "  def setup(self):\n",
        "    self.W_Q = nn.Dense(self.dim_k,\n",
        "                        kernel_init=nn.initializers.xavier_uniform(),\n",
        "                        use_bias=False,\n",
        "                        name = NAME_W_Q\n",
        "                        )\n",
        "    self.W_K = nn.Dense(self.dim_k,\n",
        "                        kernel_init=nn.initializers.xavier_uniform(),\n",
        "                        use_bias=False,\n",
        "                        name = NAME_W_K\n",
        "                        )\n",
        "    self.W_V = nn.Dense(self.dim_v,\n",
        "                        kernel_init=nn.initializers.xavier_uniform(),\n",
        "                        use_bias=False,\n",
        "                        name = NAME_W_V\n",
        "                        )\n",
        "\n",
        "  def __call__(self, x, mask=None):\n",
        "    Q = self.W_Q(x)\n",
        "    K = self.W_K(x)\n",
        "    V = self.W_V(x)\n",
        "\n",
        "    values, attention = scaled_dot_product(Q, K, V, mask=mask) # Softmax\n",
        "\n",
        "    return values, attention\n",
        "\n",
        "\n",
        "class SimpleMultiHeadAttention(nn.Module):\n",
        "  dim_k : int # key/query representation dimension\n",
        "  dim_v : int # attention head representation dimension\n",
        "  dim_E : int # number of attention heads\n",
        "  dim_h : int # transformer hidden dimension (MHA in/out dimension)\n",
        "\n",
        "  def setup(self):\n",
        "    self.heads = [Attention(dim_k=self.dim_k,\n",
        "                            dim_v=self.dim_v) for _ in range(self.dim_E)]\n",
        "    self.embeds = [nn.Dense(self.dim_h,\n",
        "                            kernel_init=nn.initializers.xavier_uniform(),\n",
        "                            use_bias=False,\n",
        "                            name = NAME_W_Oe + \"_\" + str(i)\n",
        "                            ) for i in range(self.dim_E)]\n",
        "\n",
        "  def __call__(self, x):\n",
        "    return sum([embed(head(x)[0]) for head, embed in zip(self.heads,\n",
        "                                                         self.embeds)])\n",
        "\n",
        "\n",
        "class SimpleMultiLayerPerceptron(nn.Module):\n",
        "  dim_p : int # MLP inner dimension\n",
        "  dim_h : int # transformer hidden dimension (MLP in/out dimension)\n",
        "\n",
        "  def setup(self):\n",
        "    self.layer1 = nn.Dense(self.dim_p,\n",
        "                           kernel_init=nn.initializers.xavier_uniform(),\n",
        "                           name=NAME_MLP_l1)\n",
        "    self.nonlinearity = nn.relu\n",
        "    self.layer2 = nn.Dense(self.dim_h,\n",
        "                           kernel_init=nn.initializers.xavier_uniform(),\n",
        "                           name=NAME_MLP_l2)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    return self.layer2(self.nonlinearity(self.layer1(x)))\n",
        "\n",
        "\n",
        "class SimpleTransformerLayer(nn.Module):\n",
        "  dim_k : int # key/query representation dimension\n",
        "  dim_v : int # attention head representation dimension\n",
        "  dim_E : int # number of attention heads\n",
        "  dim_h : int # transformer hidden dimension\n",
        "  dim_p : int # MLP inner dimension\n",
        "\n",
        "  def setup(self):\n",
        "    self.MHA = SimpleMultiHeadAttention(dim_k=self.dim_k,\n",
        "                                        dim_v=self.dim_v,\n",
        "                                        dim_E=self.dim_E,\n",
        "                                        dim_h=self.dim_h)\n",
        "    self.MLP = SimpleMultiLayerPerceptron(dim_p=self.dim_p,\n",
        "                                          dim_h=self.dim_h)\n",
        "    self.norms = [nn.RMSNorm(name=NAME_norm1), nn.RMSNorm(name=NAME_norm2)]\n",
        "\n",
        "  def __call__(self, x):\n",
        "    # MHA block\n",
        "    MHA_out = self.MHA(self.norms[0](x))\n",
        "    x = x + MHA_out\n",
        "\n",
        "    # MLP block\n",
        "    MLP_out = self.MLP(self.norms[1](x))\n",
        "    x = x + MLP_out\n",
        "\n",
        "    return x\n",
        "\n",
        "\n",
        "class SimplePositionalEncoding(nn.Module):\n",
        "  pos_init : Callable[[Any, Tuple[int], Any], Any]\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, inputs):\n",
        "    pos_shape = (1, inputs.shape[1], inputs.shape[2])\n",
        "    pe = self.param(NAME_P, self.pos_init, pos_shape)\n",
        "    return inputs + pe\n",
        "\n",
        "\n",
        "class SimpleTransformer(nn.Module):\n",
        "  dim_k : int # key/query representation dimension\n",
        "  dim_v : int # attention head representation dimension\n",
        "  dim_E : int # number of attention heads\n",
        "  dim_h : int # transformer hidden dimension\n",
        "  dim_p : int # MLP inner dimension\n",
        "  dim_N : int # number of transformer layers\n",
        "  dim_hout : int # final output dimension\n",
        "\n",
        "  def setup(self):\n",
        "    self.encoding = SimplePositionalEncoding(\n",
        "              pos_init=nn.initializers.normal(stddev=0.02),\n",
        "              name=NAME_P)\n",
        "    self.layers = [SimpleTransformerLayer(dim_k=self.dim_k,\n",
        "                                          dim_v=self.dim_v,\n",
        "                                          dim_E=self.dim_E,\n",
        "                                          dim_h=self.dim_h,\n",
        "                                          dim_p=self.dim_p\n",
        "                                          ) for _ in range(self.dim_N)]\n",
        "    self.outlayer = nn.Dense(self.dim_hout,\n",
        "                             kernel_init=nn.initializers.xavier_uniform(),\n",
        "                             use_bias=False)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    # Check if input needs to be zero padded\n",
        "    if x.shape[-1] \u003c self.dim_h:\n",
        "      padding = [[0, 0] for _ in range(len(x.shape))]\n",
        "      padding[-1][1] = self.dim_h - x.shape[-1]\n",
        "      x = jnp.pad(x, padding, mode=\"constant\", constant_values=0)\n",
        "\n",
        "    # Positional encoding\n",
        "    x = self.encoding(x)\n",
        "\n",
        "    # Transformer layers\n",
        "    for layer in self.layers:\n",
        "      x = layer(x)\n",
        "\n",
        "    # Final output\n",
        "    return self.outlayer(x)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-JTO-CVCtPBS"
      },
      "source": [
        "## Expansion utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QgJXULkg0jYW"
      },
      "outputs": [],
      "source": [
        "def keys_to_string_id(keys):\n",
        "  string_id = [k.key for k in keys]\n",
        "  string_id = '/'.join(string_id)\n",
        "  return string_id\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g-wD3u0a0W8_"
      },
      "outputs": [],
      "source": [
        "def params_to_dict(keys_params):\n",
        "  rtn = {}\n",
        "  for (keys, params) in keys_params:\n",
        "    string_id = keys_to_string_id(keys)\n",
        "    rtn[string_id] = params\n",
        "  return rtn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dGIXsaxP1iNc"
      },
      "outputs": [],
      "source": [
        "def params_pad_to_shape(params_source,\n",
        "                        params_target,\n",
        "                        function_preserving=True):\n",
        "  flat_params_source = jax.tree_util.tree_flatten_with_path(params_source)\n",
        "  flat_params_target = jax.tree_util.tree_flatten_with_path(params_target)\n",
        "\n",
        "  assert len(flat_params_target[0]) \u003e= len(flat_params_source[0])\n",
        "\n",
        "  sid2params_source = params_to_dict(flat_params_source[0])\n",
        "  padded_flat_params = []\n",
        "\n",
        "  for i in range(len(flat_params_target[0])):\n",
        "    string_id =  keys_to_string_id(flat_params_target[0][i][0])\n",
        "    print(string_id, \"shape:\", flat_params_target[0][i][1].shape)\n",
        "\n",
        "    if string_id in sid2params_source:\n",
        "      to_pad_params = sid2params_source[string_id]\n",
        "      print(\"Found in source model, shape:\",\n",
        "            to_pad_params.shape, \"sum:\",  jnp.sum(to_pad_params))\n",
        "      original_shape = to_pad_params.shape\n",
        "      target_shape = flat_params_target[0][i][1].shape\n",
        "\n",
        "      to_pad_shape = []\n",
        "      assert len(target_shape) == len(original_shape)\n",
        "      for i in range(len(target_shape)):\n",
        "        to_pad_shape.append((0, target_shape[i] - original_shape[i]))\n",
        "\n",
        "      print(\"to pad shape:\", to_pad_shape)\n",
        "\n",
        "      zero_init = False\n",
        "      if function_preserving:\n",
        "        # MLP expansion Sec 3.1\n",
        "        if NAME_MLP_l2 in string_id and to_pad_shape[0][1] \u003e 0:\n",
        "          zero_init = True\n",
        "        # Heads expansion Sec 3.2\n",
        "        if NAME_W_Oe in string_id and to_pad_shape[0][1] \u003e 0:\n",
        "          zero_init = True\n",
        "        # Attention expansion Sec 3.4\n",
        "        if NAME_W_K in string_id and to_pad_shape[1][1] \u003e 0:\n",
        "          zero_init = True\n",
        "        # Hidden dimension expansion Sec 3.5\n",
        "        if ((NAME_MLP_l2 in string_id and to_pad_shape[-1][1] \u003e 0) or\n",
        "            (NAME_W_Oe in string_id and to_pad_shape[1][1] \u003e 0) or\n",
        "            (NAME_P in string_id and to_pad_shape[-1][1] \u003e 0)):\n",
        "           zero_init = True\n",
        "\n",
        "      if zero_init:\n",
        "        print(\"0 init\")\n",
        "        constant_values = 0\n",
        "      else:\n",
        "        print(\"No 0 init\")\n",
        "        constant_values = 42 # placeholder\n",
        "\n",
        "      # Key matrix scaling for attention expansion Sec 3.4\n",
        "      if NAME_W_K in string_id and to_pad_shape[1][1] \u003e 0:\n",
        "        to_pad_params = to_pad_params * jnp.sqrt(target_shape[1]/\n",
        "                                                 original_shape[1])\n",
        "\n",
        "      # norm scaling for hidden dimension expansion Sec 3.5\n",
        "      if ((NAME_norm1 in string_id or NAME_norm2 in string_id) and\n",
        "              to_pad_shape[-1][1] \u003e 0):\n",
        "        to_pad_params = to_pad_params * jnp.sqrt(original_shape[0]/\n",
        "                                                 target_shape[-1])\n",
        "\n",
        "      padded_params = jnp.pad(to_pad_params, to_pad_shape,\n",
        "                              'constant', constant_values=constant_values)\n",
        "\n",
        "      print(\"Padded\" ,padded_params.shape, \"sum:\", jnp.sum(padded_params))\n",
        "      padded_flat_params.append(padded_params)\n",
        "    else:\n",
        "      print(\"Not found in source model\")\n",
        "      zero_init = False\n",
        "      if function_preserving:\n",
        "        # Head addition Sec 3.3\n",
        "        if NAME_W_Oe in string_id:\n",
        "          zero_init = True\n",
        "        # Layer addition Sec 3.6\n",
        "        if ((NAME_MLP_l2 in string_id) or\n",
        "            (NAME_W_Oe in string_id)):\n",
        "          zero_init = True\n",
        "\n",
        "      if zero_init:\n",
        "        print(\"0 init\")\n",
        "        padded_flat_params.append(jnp.zeros_like(flat_params_target[0][i][1]))\n",
        "      else:\n",
        "        print(\"No 0 init\")\n",
        "        padded_flat_params.append(flat_params_target[0][i][1])\n",
        "\n",
        "\n",
        "    print(\"----\")\n",
        "  return jax.tree_util.tree_unflatten(flat_params_target[1], padded_flat_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iJYpHX7IsuQi"
      },
      "source": [
        "## MultiLayer Perceptron block\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QIJQ21DeRDye"
      },
      "source": [
        "#### MLP base"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T_HCGm8XRDPC"
      },
      "outputs": [],
      "source": [
        "ORIGINAL_p = 6\n",
        "ORIGINAL_h = 5\n",
        "BATCH = 3\n",
        "SEQUENCE = 2\n",
        "\n",
        "main_rng, X_rng = random.split(main_rng)\n",
        "X_MLP = random.normal(X_rng, (BATCH, SEQUENCE, ORIGINAL_h))\n",
        "\n",
        "mlp = SimpleMultiLayerPerceptron(dim_h=ORIGINAL_h, dim_p=ORIGINAL_p)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "params = mlp.init(init_rng, X_MLP)['params']\n",
        "O_MLP = mlp.apply({'params': params}, X_MLP)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQk7lQrVRUZV"
      },
      "source": [
        "#### MLP expansion (Section 3.1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m0e51-lwRUZg"
      },
      "outputs": [],
      "source": [
        "MOD_p = ORIGINAL_p + 3\n",
        "MOD_mlp = SimpleMultiLayerPerceptron(dim_h=ORIGINAL_h, dim_p=MOD_p)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "MOD_params_init = MOD_mlp.init(init_rng, X_MLP)['params']\n",
        "\n",
        "MOD_params = params_pad_to_shape(params, MOD_params_init)\n",
        "\n",
        "MOD_O_MLP = MOD_mlp.apply({'params': MOD_params}, X_MLP)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O_MLP), \"Expanded MLP:\", jnp.sum(MOD_O_MLP))\n",
        "assert jnp.allclose(MOD_O_MLP, O_MLP, rtol = 1e-3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eibtca9Psnfi"
      },
      "source": [
        "## MultiHead Attention block"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ry6HS8ZaxhDO"
      },
      "source": [
        "#### MHA base"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6h708J8nxkxg"
      },
      "outputs": [],
      "source": [
        "ORIGINAL_k = 6\n",
        "ORIGINAL_v = 6\n",
        "ORIGINAL_h = 5\n",
        "ORIGINAL_E = 2\n",
        "BATCH = 3\n",
        "SEQUENCE = 2\n",
        "\n",
        "main_rng, I_rng = random.split(main_rng)\n",
        "X_MHA = random.normal(I_rng, (BATCH, SEQUENCE, ORIGINAL_h))\n",
        "\n",
        "mha = SimpleMultiHeadAttention(dim_k=ORIGINAL_k,\n",
        "                                dim_v=ORIGINAL_v,\n",
        "                                dim_h=ORIGINAL_h,\n",
        "                                dim_E=ORIGINAL_E)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "params = mha.init(init_rng, X_MHA)['params']\n",
        "O_MHA = mha.apply({'params': params}, X_MHA)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tINOYkVkwfh3"
      },
      "source": [
        "#### Head addition on MHA (Section 3.2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4EetFiePwp_I"
      },
      "outputs": [],
      "source": [
        "MOD_E = ORIGINAL_E + 3\n",
        "MOD_mha = SimpleMultiHeadAttention(\n",
        "    dim_k=ORIGINAL_k,\n",
        "    dim_v=ORIGINAL_v,\n",
        "    dim_h=ORIGINAL_h,\n",
        "    dim_E=MOD_E)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']\n",
        "\n",
        "MOD_params = params_pad_to_shape(params, MOD_params_init)\n",
        "\n",
        "MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O_MHA), \"Added heads:\", jnp.sum(MOD_O_MHA))\n",
        "assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6cccL0nVjy_7"
      },
      "source": [
        "#### Heads expansion on MHA (Section 3.3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "THCWlFlAjy_8"
      },
      "outputs": [],
      "source": [
        "MOD_v = ORIGINAL_v + 3\n",
        "MOD_mha = SimpleMultiHeadAttention(\n",
        "    dim_k=ORIGINAL_k,\n",
        "    dim_v=MOD_v,\n",
        "    dim_h=ORIGINAL_h,\n",
        "    dim_E=ORIGINAL_E)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']\n",
        "\n",
        "MOD_params = params_pad_to_shape(params, MOD_params_init)\n",
        "\n",
        "MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O_MHA), \"Expanded heads:\", jnp.sum(MOD_O_MHA))\n",
        "assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GVeLODnCrX3r"
      },
      "source": [
        "#### Attention expansion on MHA (Section 3.4)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uBpU6CHOrX3s"
      },
      "outputs": [],
      "source": [
        "MOD_k = ORIGINAL_k + 3\n",
        "MOD_mha = SimpleMultiHeadAttention(\n",
        "    dim_k=MOD_k,\n",
        "    dim_v=ORIGINAL_v,\n",
        "    dim_h=ORIGINAL_h,\n",
        "    dim_E=ORIGINAL_E)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']\n",
        "\n",
        "MOD_params = params_pad_to_shape(params, MOD_params_init)\n",
        "\n",
        "MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O_MHA),\n",
        "      \"Expanded key/query representation:\", jnp.sum(MOD_O_MHA))\n",
        "assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ETmQZ6snjJAh"
      },
      "source": [
        "#### MHA Combination"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fi9tG2VNjIpt"
      },
      "outputs": [],
      "source": [
        "MOD_mha = SimpleMultiHeadAttention(\n",
        "    dim_k=MOD_k,\n",
        "    dim_v=MOD_v,\n",
        "    dim_h=ORIGINAL_h,\n",
        "    dim_E=MOD_E)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']\n",
        "\n",
        "MOD_params = params_pad_to_shape(params, MOD_params_init)\n",
        "\n",
        "MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O_MHA),\n",
        "      \"Expanded key/query representation:\", jnp.sum(MOD_O_MHA))\n",
        "assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VLUsh-9esdzv"
      },
      "source": [
        "## Transformer layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f0m9iy9bk975"
      },
      "source": [
        "#### Layer base"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bm9RM5myk9mr"
      },
      "outputs": [],
      "source": [
        "ORIGINAL_k = 6\n",
        "ORIGINAL_v = 6\n",
        "ORIGINAL_h = 5\n",
        "ORIGINAL_E = 2\n",
        "ORIGINAL_p = 4\n",
        "BATCH = 3\n",
        "SEQUENCE = 2\n",
        "\n",
        "main_rng, I_rng = random.split(main_rng)\n",
        "I_layer = random.normal(I_rng, (BATCH, SEQUENCE, ORIGINAL_h))\n",
        "\n",
        "layer = SimpleTransformerLayer(dim_k=ORIGINAL_k,\n",
        "                               dim_v=ORIGINAL_v,\n",
        "                               dim_h=ORIGINAL_h,\n",
        "                               dim_E=ORIGINAL_E,\n",
        "                               dim_p=ORIGINAL_p)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "params = layer.init(init_rng, I_layer)['params']\n",
        "O_layer = layer.apply({'params': params}, I_layer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0crZNTNN6ZkJ"
      },
      "source": [
        "#### Hidden Dimension Expansion (3.5)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lSEXsyWHmbbz"
      },
      "outputs": [],
      "source": [
        "MOD_h = ORIGINAL_h + 3\n",
        "MOD_I_layer = jnp.pad(I_layer, ((0, 0), (0, 0), (0, MOD_h-ORIGINAL_h)),\n",
        "                      'constant', constant_values=0)\n",
        "O_layer_padded = jnp.pad(O_layer, ((0, 0), (0, 0), (0, MOD_h-ORIGINAL_h)),\n",
        "                         'constant', constant_values=0)\n",
        "MOD_layer = SimpleTransformerLayer(dim_h=MOD_h,\n",
        "                                   dim_E=ORIGINAL_E,\n",
        "                                   dim_p=ORIGINAL_p,\n",
        "                                   dim_k=ORIGINAL_k,\n",
        "                                   dim_v=ORIGINAL_v)\n",
        "MOD_params = MOD_layer.init(init_rng, MOD_I_layer)['params']\n",
        "MOD_params = params_pad_to_shape(params, MOD_params)\n",
        "MOD_O_layer = MOD_layer.apply({'params': MOD_params}, MOD_I_layer)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O_layer),\n",
        "      \"Expanded hidden dimension:\", jnp.sum(MOD_O_layer))\n",
        "assert jnp.allclose(MOD_O_layer, O_layer_padded, rtol = 1e-3)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f8wd2N_HsXQZ"
      },
      "source": [
        "## Full model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4R_FmfGMtAtH"
      },
      "source": [
        "#### Transformer base"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dz2xgBqvtAYn"
      },
      "outputs": [],
      "source": [
        "ORIGINAL_k = 6\n",
        "ORIGINAL_v = 6\n",
        "ORIGINAL_h = 5\n",
        "ORIGINAL_E = 2\n",
        "ORIGINAL_p = 4\n",
        "ORIGINAL_N = 2\n",
        "BATCH = 3\n",
        "SEQUENCE = 2\n",
        "HOUT = 2\n",
        "\n",
        "# Input representation\n",
        "main_rng, I_rng = random.split(main_rng)\n",
        "I = random.normal(I_rng, (BATCH, SEQUENCE, ORIGINAL_h))\n",
        "\n",
        "# MHA block\n",
        "model = SimpleTransformer(dim_k=ORIGINAL_k,\n",
        "                          dim_v=ORIGINAL_v,\n",
        "                          dim_h=ORIGINAL_h,\n",
        "                          dim_E=ORIGINAL_E,\n",
        "                          dim_p=ORIGINAL_p,\n",
        "                          dim_N=ORIGINAL_N,\n",
        "                          dim_hout=HOUT)\n",
        "main_rng, init_rng = random.split(main_rng)\n",
        "params = model.init(init_rng, I)['params']\n",
        "O = model.apply({'params': params}, I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZHZ5i8z0s4ff"
      },
      "source": [
        "#### Layer Addition (3.6)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TUxjOQ88jIBr"
      },
      "outputs": [],
      "source": [
        "MOD_N = ORIGINAL_N + 3\n",
        "MOD_model = SimpleTransformer(dim_h=ORIGINAL_h,\n",
        "                              dim_E=ORIGINAL_E,\n",
        "                              dim_p=ORIGINAL_p,\n",
        "                              dim_k=ORIGINAL_k,\n",
        "                              dim_v=ORIGINAL_v,\n",
        "                              dim_N=MOD_N,\n",
        "                              dim_hout=HOUT)\n",
        "MOD_params = MOD_model.init(init_rng, I)['params']\n",
        "MOD_params = params_pad_to_shape(params, MOD_params)\n",
        "MOD_O = MOD_model.apply({'params': MOD_params}, I)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O), \"Added layers:\", jnp.sum(MOD_O))\n",
        "assert jnp.allclose(MOD_O, O, rtol = 1e-3)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4P-aQIoMMrNA"
      },
      "source": [
        "#### Hidden Dimension Expansion (Sec. 3.5) on full model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SNfx7PKmzFkP"
      },
      "outputs": [],
      "source": [
        "MOD_model = SimpleTransformer(dim_h=MOD_h,\n",
        "                              dim_E=ORIGINAL_E,\n",
        "                              dim_p=ORIGINAL_p,\n",
        "                              dim_k=ORIGINAL_k,\n",
        "                              dim_v=ORIGINAL_v,\n",
        "                              dim_N=ORIGINAL_N,\n",
        "                              dim_hout=HOUT)\n",
        "MOD_params = MOD_model.init(init_rng, I)['params']\n",
        "MOD_params = params_pad_to_shape(params, MOD_params)\n",
        "MOD_O = MOD_model.apply({'params': MOD_params}, I)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O), \"Expanded hidden dimension:\", jnp.sum(MOD_O))\n",
        "assert jnp.allclose(MOD_O, O, rtol = 1e-3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8m_thLhV6fJ1"
      },
      "source": [
        "#### All transformations on full model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aYNSYK1l53xW"
      },
      "outputs": [],
      "source": [
        "MOD_model = SimpleTransformer(dim_h=MOD_h,\n",
        "                              dim_E=MOD_E,\n",
        "                              dim_p=MOD_p,\n",
        "                              dim_k=MOD_k,\n",
        "                              dim_v=MOD_v,\n",
        "                              dim_N=MOD_N,\n",
        "                              dim_hout=HOUT)\n",
        "MOD_params = MOD_model.init(init_rng, I)['params']\n",
        "MOD_params = params_pad_to_shape(params, MOD_params)\n",
        "MOD_O = MOD_model.apply({'params': MOD_params}, I)\n",
        "\n",
        "print(\"Original:\",  jnp.sum(O), \"Expanded hidden dimension:\", jnp.sum(MOD_O))\n",
        "assert jnp.allclose(MOD_O, O, rtol = 1e-3)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1YlVtnBuAoNbj9l7NAuixtS5sV9dHHcO6",
          "timestamp": 1691050723749
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
